[cuBLAS] Add cublas_gemm_batched and use cublasSetStream to set stream to the current stream in all cublas API calls #423
Thanks @yudi0201 !
It looks good to me.
src/hidet/runtime/cuda/cuda.cpp
Outdated
@@ -79,3 +88,18 @@ DLL void hidet_cuda_set_device(int device) {
    lazy_load_cuda_runtime();
    CHECK_CUDA(cudaSetDevice(device));
}

DLL void hidet_cuda_malloc(void **devPtr, size_t size) {
Consider returning the allocated memory address directly, like
DLL void* hidet_cuda_malloc(size_t size) {
...
}
Like hidet_cuda_get_device(...).
Thanks @yudi0201!
src/hidet/runtime/cuda/cublas.cpp
Outdated
// Allocate device memory
// first use synchronous versions of malloc and memcpy, later switch to async versions
if (cur_device_ptr_size != 0 && b > cur_device_ptr_size) {
    hidet_cuda_free((void *)ptr_a_device);
Why not hidet_cuda_free_async?
The following logic is more readable to me, just as a reference.
if (b > cur_device_ptr_size) {
    if (cur_device_ptr_size > 0) {
        free the three ptrs
    }
    alloc three ptrs
}
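As a rough illustration of the structure suggested above, here is a minimal host-side sketch of the "grow only when needed" pattern. It is not the PR's code: `PtrBuffers`, `ensure_capacity`, `demo_alloc`, and `demo_free` are hypothetical names, and plain `std::malloc`/`std::free` stand in for the device allocation calls so the logic can be run without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical host-side stand-ins for the device allocator; real code would
// call the runtime's device malloc/free functions instead.
static void *demo_alloc(std::size_t bytes) { return std::malloc(bytes); }
static void demo_free(void *p) { std::free(p); }

// Cached pointer arrays for the A, B and C operands (illustrative names).
struct PtrBuffers {
    void **a = nullptr;
    void **b = nullptr;
    void **c = nullptr;
    std::size_t capacity = 0;  // number of pointers each array can hold
};

// Reallocate the three arrays only when `batch` exceeds the current capacity.
// Returns true if a reallocation happened, false if the arrays were reused.
static bool ensure_capacity(PtrBuffers &bufs, std::size_t batch) {
    if (batch <= bufs.capacity) {
        return false;  // existing arrays are large enough
    }
    if (bufs.capacity > 0) {
        // Free the three old arrays before allocating larger ones.
        demo_free(bufs.a);
        demo_free(bufs.b);
        demo_free(bufs.c);
    }
    bufs.a = static_cast<void **>(demo_alloc(batch * sizeof(void *)));
    bufs.b = static_cast<void **>(demo_alloc(batch * sizeof(void *)));
    bufs.c = static_cast<void **>(demo_alloc(batch * sizeof(void *)));
    assert(bufs.a && bufs.b && bufs.c);
    bufs.capacity = batch;
    return true;
}
```

Nesting the "free" step inside the "needs to grow" check means the size comparison is written once, and the first call (capacity 0) simply skips the free.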
Thanks for the suggestions! I'll modify these in the next revision.
src/hidet/runtime/cuda/cuda.cpp
Outdated

DLL void* hidet_cuda_malloc(size_t size) {
    lazy_load_cuda_runtime();
    void* devPtr = malloc(sizeof(void*));
- void* devPtr = malloc(sizeof(void*));
+ void* devPtr;
We do not need to allocate a memory region.
Maybe I'm missing something, but wouldn't doing this result in devPtr being created on the stack, and then we'd be returning a local stack variable?
When we call cudaMallocAsync(&devPtr, ...), we pass the address of the pointer to the CUDA API function, which updates the pointer's value. The pointer is a stack variable and remains valid for the duration of the CUDA API call. We return the "value" of the pointer, not its "address", to the caller of hidet_cuda_malloc, so this is fine.
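To make the point concrete, here is a small host-only sketch of the same out-parameter pattern. `demo_malloc` and `demo_wrapper_malloc` are hypothetical names; `demo_malloc` merely imitates the shape of a CUDA-style allocator that writes its result through a `void **` argument, using `std::malloc` so the example runs without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for an allocator with a CUDA-style signature:
// the new address is written through the out-parameter `out`.
static int demo_malloc(void **out, std::size_t size) {
    *out = std::malloc(size);  // store the address in the caller's pointer
    return *out ? 0 : 1;       // 0 imitates a success status code
}

// Wrapper that returns the allocated address by value.
static void *demo_wrapper_malloc(std::size_t size) {
    void *ptr = nullptr;                   // local pointer, lives on the stack
    int status = demo_malloc(&ptr, size);  // callee fills in `ptr`
    assert(status == 0);
    (void)status;                          // silence unused warning under NDEBUG
    return ptr;  // copies the stored address out; `ptr` itself is not returned
}
```

Only returning `&ptr` would hand out the address of a dead stack variable. Returning `ptr` copies the allocated address, which stays valid after the wrapper returns, so no extra `malloc(sizeof(void*))` is needed.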
Got it. Thanks!
src/hidet/runtime/cuda/cuda.cpp
Outdated

DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream) {
    lazy_load_cuda_runtime();
    void* devPtr = malloc(sizeof(void*));
Same as the previous one.
Thanks @yudi0201!
No description provided.